#region <<版权版本注释>>

// ----------------------------------------------------------------
// Copyright ©2021-Present ZhaiFanhua All Rights Reserved.
// Licensed under the MIT License. See LICENSE in the project root for license information.
// FileName:DataPermissionFilter
// Guid:cc2b3c4d-5e6f-7890-abcd-ef12345678c1
// Author:zhaifanhua
// Email:me@zhaifanhua.com
// CreateTime:2025/10/31 7:35:00
// ----------------------------------------------------------------

#endregion <<版权版本注释>>

using System.Linq.Expressions;
using System.Reflection;
using XiHan.BasicApp.Rbac.DataPermissions.Enums;
using XiHan.BasicApp.Rbac.Repositories.Abstractions;

namespace XiHan.BasicApp.Rbac.DataPermissions.Filters;

/// <summary>
/// Data permission filter implementation. Restricts queries and single-record checks according to a
/// <see cref="DataPermissionScope"/> (all, self only, own departments, own departments and descendants).
/// </summary>
/// <remarks>
/// Entity fields are resolved by name via reflection. Only public instance properties of type
/// <c>RbacIdType</c> or <c>RbacIdType?</c> are recognized. An entity lacking the requested field is
/// left unfiltered (the builders return <c>null</c>).
/// </remarks>
public class DataPermissionFilter : IDataPermissionFilter
{
    private readonly IUserRepository _userRepository;
    private readonly IDepartmentRepository _departmentRepository;

    /// <summary>
    /// Creates the filter.
    /// </summary>
    /// <param name="userRepository">Used to look up the department ids a user belongs to.</param>
    /// <param name="departmentRepository">Used to walk the department tree via <c>GetChildrenAsync</c>.</param>
    public DataPermissionFilter(
        IUserRepository userRepository,
        IDepartmentRepository departmentRepository)
    {
        _userRepository = userRepository;
        _departmentRepository = departmentRepository;
    }

    /// <summary>
    /// Applies the data permission filter for <paramref name="userId"/> and <paramref name="scope"/> to
    /// <paramref name="query"/>, using the default field names (<c>DepartmentId</c>, <c>CreatedBy</c>).
    /// </summary>
    /// <param name="query">The query to restrict.</param>
    /// <param name="userId">The user whose permissions apply.</param>
    /// <param name="scope">The data permission scope of that user.</param>
    /// <returns>The original query when no predicate applies; otherwise the query with a <c>Where</c> appended.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="query"/> is <c>null</c>.</exception>
    /// <remarks>
    /// This synchronous entry point blocks on <c>BuildFilterExpressionAsync</c>. Prefer the async method
    /// when the caller can await.
    /// </remarks>
    public IQueryable<TEntity> ApplyFilter<TEntity>(
        IQueryable<TEntity> query,
        RbacIdType userId,
        DataPermissionScope scope) where TEntity : class
    {
        ArgumentNullException.ThrowIfNull(query);

        // NOTE(review): sync-over-async, kept because this signature is synchronous. The awaits inside use
        // ConfigureAwait(false) to lower deadlock risk when a synchronization context is present.
        var expression = BuildFilterExpressionAsync<TEntity>(userId, scope).GetAwaiter().GetResult();
        return expression is null ? query : query.Where(expression);
    }

    /// <summary>
    /// Builds the predicate that implements <paramref name="scope"/> for <paramref name="userId"/>.
    /// </summary>
    /// <param name="userId">The user whose permissions apply.</param>
    /// <param name="scope">The data permission scope of that user.</param>
    /// <param name="departmentField">Name of the entity property holding the owning department id.</param>
    /// <param name="creatorField">Name of the entity property holding the creator user id.</param>
    /// <returns>
    /// <c>null</c> when no restriction is applied (scope <c>All</c>, <c>Custom</c>, unknown scopes, or the
    /// entity lacks the relevant field); otherwise the predicate. A department scope for a user with no
    /// departments yields a predicate that matches nothing.
    /// </returns>
    public async Task<Expression<Func<TEntity, bool>>?> BuildFilterExpressionAsync<TEntity>(
        RbacIdType userId,
        DataPermissionScope scope,
        string departmentField = "DepartmentId",
        string creatorField = "CreatedBy") where TEntity : class
    {
        switch (scope)
        {
            case DataPermissionScope.All:
                // Unrestricted access: no predicate needed.
                return null;

            case DataPermissionScope.SelfOnly:
                return BuildSelfOnlyExpression<TEntity>(userId, creatorField);

            case DataPermissionScope.DepartmentOnly:
                return await BuildDepartmentOnlyExpressionAsync<TEntity>(userId, departmentField).ConfigureAwait(false);

            case DataPermissionScope.DepartmentAndChildren:
                return await BuildDepartmentAndChildrenExpressionAsync<TEntity>(userId, departmentField).ConfigureAwait(false);

            case DataPermissionScope.Custom:
                // Custom scopes need a dedicated filter; nothing is applied here.
                // NOTE(review): null means "unrestricted" to ApplyFilter, while HasPermissionAsync denies
                // Custom. Confirm this fail-open behavior is intended.
                return null;

            default:
                // NOTE(review): unknown scopes are unrestricted here but denied by HasPermissionAsync.
                return null;
        }
    }

    /// <summary>
    /// Checks whether <paramref name="userId"/> may access a single record owned by
    /// <paramref name="targetUserId"/> / <paramref name="targetDepartmentId"/> under <paramref name="scope"/>.
    /// </summary>
    /// <param name="userId">The user requesting access.</param>
    /// <param name="targetUserId">Owner (user) of the record, if any.</param>
    /// <param name="targetDepartmentId">Owning department of the record, if any.</param>
    /// <param name="scope">The data permission scope of the requesting user.</param>
    /// <returns>
    /// <c>true</c> for <c>All</c>. For <c>SelfOnly</c>, <c>true</c> when the target user is the requester.
    /// For the department scopes, <c>true</c> when the target department is among the user's departments
    /// (plus descendants for <c>DepartmentAndChildren</c>). <c>false</c> otherwise, including
    /// <c>Custom</c> and a missing target department.
    /// </returns>
    public async Task<bool> HasPermissionAsync(
        RbacIdType userId,
        RbacIdType? targetUserId,
        RbacIdType? targetDepartmentId,
        DataPermissionScope scope)
    {
        switch (scope)
        {
            case DataPermissionScope.All:
                return true;

            case DataPermissionScope.SelfOnly:
                // Lifted comparison: a null targetUserId never matches.
                return targetUserId == userId;

            case DataPermissionScope.DepartmentOnly:
                if (!targetDepartmentId.HasValue)
                {
                    return false;
                }

                var userDepartmentIds = await _userRepository.GetUserDepartmentIdsAsync(userId).ConfigureAwait(false);
                return userDepartmentIds.Contains(targetDepartmentId.Value);

            case DataPermissionScope.DepartmentAndChildren:
                if (!targetDepartmentId.HasValue)
                {
                    return false;
                }

                var allDepartmentIds = await GetUserAllDepartmentIdsAsync(userId).ConfigureAwait(false);
                return allDepartmentIds.Contains(targetDepartmentId.Value);

            default:
                return false;
        }
    }

    #region Private helpers

    /// <summary>
    /// Resolves a public instance property named <paramref name="propertyName"/> on <typeparamref name="TEntity"/>
    /// whose type is <c>RbacIdType</c> or <c>RbacIdType?</c>.
    /// </summary>
    /// <returns>The property, or <c>null</c> when it is missing or has another type.</returns>
    private static PropertyInfo? FindIdProperty<TEntity>(string propertyName) where TEntity : class
    {
        var property = typeof(TEntity).GetProperty(propertyName, BindingFlags.Public | BindingFlags.Instance);
        if (property is null)
        {
            return null;
        }

        var isIdType = property.PropertyType == typeof(RbacIdType) || property.PropertyType == typeof(RbacIdType?);
        return isIdType ? property : null;
    }

    /// <summary>
    /// Builds <c>x => x.{creatorField} == userId</c>. For a nullable field, a null value never matches.
    /// </summary>
    /// <returns>The predicate, or <c>null</c> when the entity has no usable creator field.</returns>
    private static Expression<Func<TEntity, bool>>? BuildSelfOnlyExpression<TEntity>(
        RbacIdType userId,
        string creatorField) where TEntity : class
    {
        var property = FindIdProperty<TEntity>(creatorField);
        if (property is null)
        {
            return null;
        }

        var parameter = Expression.Parameter(typeof(TEntity), "x");
        var propertyAccess = Expression.Property(parameter, property);

        Expression comparison;
        if (property.PropertyType == typeof(RbacIdType?))
        {
            // Nullable field: x.Field.HasValue && x.Field.Value == userId
            var hasValue = Expression.Property(propertyAccess, "HasValue");
            var value = Expression.Property(propertyAccess, "Value");
            comparison = Expression.AndAlso(hasValue, Expression.Equal(value, Expression.Constant(userId)));
        }
        else
        {
            comparison = Expression.Equal(propertyAccess, Expression.Constant(userId));
        }

        return Expression.Lambda<Func<TEntity, bool>>(comparison, parameter);
    }

    /// <summary>
    /// Builds the predicate restricting rows to the user's own departments.
    /// </summary>
    private async Task<Expression<Func<TEntity, bool>>?> BuildDepartmentOnlyExpressionAsync<TEntity>(
        RbacIdType userId,
        string departmentField) where TEntity : class
    {
        var userDepartmentIds = await _userRepository.GetUserDepartmentIdsAsync(userId).ConfigureAwait(false);

        // An empty department list is handled by BuildDepartmentExpression (deny-all), rather than
        // returning null here, which would mean "unrestricted".
        return BuildDepartmentExpression<TEntity>(userDepartmentIds, departmentField);
    }

    /// <summary>
    /// Builds the predicate restricting rows to the user's departments and all their descendants.
    /// </summary>
    private async Task<Expression<Func<TEntity, bool>>?> BuildDepartmentAndChildrenExpressionAsync<TEntity>(
        RbacIdType userId,
        string departmentField) where TEntity : class
    {
        var allDepartmentIds = await GetUserAllDepartmentIdsAsync(userId).ConfigureAwait(false);
        return BuildDepartmentExpression<TEntity>(allDepartmentIds, departmentField);
    }

    /// <summary>
    /// Builds <c>x => departmentIds.Contains(x.{departmentField})</c>. For a nullable field, a null value never matches.
    /// </summary>
    /// <returns>
    /// <c>null</c> when the entity has no usable department field. A predicate matching nothing when
    /// <paramref name="departmentIds"/> is null or empty (fail closed). Otherwise the membership predicate.
    /// </returns>
    private static Expression<Func<TEntity, bool>>? BuildDepartmentExpression<TEntity>(
        List<RbacIdType> departmentIds,
        string departmentField) where TEntity : class
    {
        var property = FindIdProperty<TEntity>(departmentField);
        if (property is null)
        {
            return null;
        }

        var parameter = Expression.Parameter(typeof(TEntity), "x");

        if (departmentIds is null || departmentIds.Count == 0)
        {
            // No departments means no department-scoped rows are visible.
            return Expression.Lambda<Func<TEntity, bool>>(Expression.Constant(false), parameter);
        }

        var propertyAccess = Expression.Property(parameter, property);
        var isNullable = property.PropertyType == typeof(RbacIdType?);
        Expression idValue = isNullable ? Expression.Property(propertyAccess, "Value") : propertyAccess;

        // Enumerable.Contains<RbacIdType>(departmentIds, id), which ORMs typically translate to SQL IN (...).
        Expression contains = Expression.Call(
            typeof(Enumerable),
            nameof(Enumerable.Contains),
            [typeof(RbacIdType)],
            Expression.Constant(departmentIds),
            idValue);

        // Nullable field: guard with HasValue so .Value is never read on null.
        var comparison = isNullable
            ? Expression.AndAlso(Expression.Property(propertyAccess, "HasValue"), contains)
            : contains;

        return Expression.Lambda<Func<TEntity, bool>>(comparison, parameter);
    }

    /// <summary>
    /// Gets the distinct ids of the user's departments plus all their descendant departments.
    /// </summary>
    /// <remarks>
    /// Walks the tree breadth-first with a visited set. Each department is expanded at most once, which
    /// avoids redundant lookups when the user's departments overlap and cannot loop forever on a cyclic
    /// parent chain.
    /// </remarks>
    private async Task<List<RbacIdType>> GetUserAllDepartmentIdsAsync(RbacIdType userId)
    {
        var userDepartmentIds = await _userRepository.GetUserDepartmentIdsAsync(userId).ConfigureAwait(false);

        var result = new List<RbacIdType>();
        var visited = new HashSet<RbacIdType>();
        var pending = new Queue<RbacIdType>();

        foreach (var departmentId in userDepartmentIds)
        {
            if (visited.Add(departmentId))
            {
                result.Add(departmentId);
                pending.Enqueue(departmentId);
            }
        }

        while (pending.TryDequeue(out var currentId))
        {
            var children = await _departmentRepository.GetChildrenAsync(currentId).ConfigureAwait(false);
            foreach (var child in children)
            {
                if (visited.Add(child.BasicId))
                {
                    result.Add(child.BasicId);
                    pending.Enqueue(child.BasicId);
                }
            }
        }

        return result;
    }

    #endregion
}